package jwt

import (
	"fmt"
	"log"
	"strings"
	"time"

	"gitee.com/chenhonghua/ginorigin/http/restful"
	jwtpkg "github.com/dgrijalva/jwt-go"
	"github.com/gin-gonic/gin"
	"github.com/patrickmn/go-cache"
	"golang.org/x/sync/singleflight"
)

var (
	jwtConfig JWTConfig

	TOKEN_CACHE        *cache.Cache  = cache.New(time.Duration(0), time.Duration(0))
	TOKEN_CACHE_EXPIRE time.Duration = time.Hour * 24

	concurrencyControl = &singleflight.Group{} // 单线程操作，防止缓存击穿
	signKey            = []byte("jwt")
)

type JWTConfig struct {
	Enable             bool   `mapstructure:"enable" json:"enable" yaml:"enable"`                                         // 是否开启
	TokenHeader        string `mapstructure:"token-header" json:"tokenHeader" yaml:"token-header"`                        // 请求中的header名
	NewTokenHeader     string `mapstructure:"new-token-header" json:"newTokenHeader" yaml:"new-token-header"`             // 响应新token的header名
	SignKey            string `mapstructure:"sign-key" json:"signKey" yaml:"sign-key"`                                    // jwt签名
	ExpireTime         int64  `mapstructure:"expire-time" json:"expireTime" yaml:"expire-time"`                           // 过期时间
	RefreshIntervalMin int64  `mapstructure:"refresh-interval-min" json:"refreshIntervalMin" yaml:"refresh-interval-min"` // 刷新token的最小间隔时间
	Issuer             string `mapstructure:"issuer" json:"issuer" yaml:"issuer"`                                         // 签发者
	HijackCheckEnable  bool   `mapstructure:"hijack-check-enable" json:"hijackCheckEnable" yaml:"hijack-check-enable"`    // 是否开启劫持检测
}

func isEnables() bool {
	enable := jwtConfig.Enable && len(jwtConfig.SignKey) > 0
	if !enable {
		log.Printf("未启用JWT:%v\n", jwtConfig)
	}
	return enable
}

func (j JWTConfig) Load() {
	jwtConfig = j
	if !isEnables() {
		return
	}
	signKey = []byte(jwtConfig.SignKey)
	TOKEN_CACHE_EXPIRE = time.Duration(jwtConfig.ExpireTime) * time.Second
	log.Println("启用jwt模块，开始进行配置")
}

// 公共jwt认证处理
var JWT_HANDLER gin.HandlerFunc = func(ctx *gin.Context) {
	ts := ctx.Request.Header.Get(jwtConfig.TokenHeader)
	// log.Debugf("开始进行jwt请求预处理:token=%s\n", ts)
	if len(ts) == 0 {
		restful.Errors.FromError(ErrNotLogin).Restful(ctx)
		ctx.Abort()
		return
	}
	claims, e := ParseTokenString(ts)
	if e != nil {
		restful.Errors.FromError(e).Restful(ctx)
		ctx.Abort()
		return
	}
	if jwtConfig.HijackCheckEnable {
		cacheToken, err := GetToken(claims.Id)
		if err != nil || !strings.EqualFold(cacheToken, ts) { // 缓存过期的token(maybe绕过了jwt的检查)、伪造的有效token、劫持的token
			restful.Errors.FromError(ErrTokenExpired).Restful(ctx)
			ctx.Abort()
			return
		}
	}
	// 验证通过
	n := time.Now().Unix()
	lastIssuedAt := claims.ExpiresAt - jwtConfig.ExpireTime // 此token的真实签发时间
	aliveTime := n - lastIssuedAt                           // 此token已使用的时间（秒）
	if aliveTime > jwtConfig.RefreshIntervalMin {           // token有效，但已大于刷新最小间隔，刷新token
		claims.ExpiresAt = n + jwtConfig.ExpireTime // 更新token的过期时间
		nt := jwtpkg.NewWithClaims(jwtpkg.SigningMethodHS256, claims)
		newTokenString, _ := nt.SignedString(signKey)        // 签名token实例，生成tokenString
		ctx.Header(jwtConfig.NewTokenHeader, newTokenString) // 通知客户端更换新token
		if jwtConfig.HijackCheckEnable {
			SetToken(claims.Id, newTokenString)
		}
	}
	SetClaims(ctx, claims) // 将claims信息绑定在gin.context上，减少后续解析
	ctx.Next()
}

// 缓存键
func jwt_cacheKey(id string) string { return fmt.Sprintf("TOKEN_CACHE_KEY_PREFIX::%s", id) }

// 从缓存获取会话token
func GetToken(id string) (string, error) { // 传入参数可以自行定义
	if r, b := TOKEN_CACHE.Get(jwt_cacheKey(id)); b {
		return r.(string), nil
	}
	return "", ErrTokenNotFound
}

func SetToken(id string, token string) { // 传入参数可以自行定义
	TOKEN_CACHE.Set(jwt_cacheKey(id), token, TOKEN_CACHE_EXPIRE)
}

func RemoveToken(id string) {
	TOKEN_CACHE.Delete(jwt_cacheKey(id))
}

func NewToken(id string) (ts string, err error) {
	if !isEnables() {
		return "", ErrJwtNotActive
	}
	n := time.Now().Unix() // 当前时间秒值
	c := jwtpkg.StandardClaims{
		Subject:   jwtConfig.Issuer,         // 主题
		Issuer:    jwtConfig.Issuer,         // 签名的发行者
		IssuedAt:  n,                        // 初次签名时间
		NotBefore: n - jwtConfig.ExpireTime, // 签名生效时间，小于当前，则立刻生效
		ExpiresAt: n + jwtConfig.ExpireTime, // 过期时间
		Audience:  "",                       // 接收方
		Id:        id,                       // jwt的唯一身份标识，主要用来作为一次性token,从而回避重放攻击
	}
	// c := &BaseClaims{ // 这里的创建实例，需要加引用，因为jwtpkg.ParseWithClaims使用的是Claims接口，如果在构建实例时类型太明确，会导致tokenString反序列化成实例时失败
	// 	StandardClaims: sc,
	// 	UserId:         uid,
	// }
	t := jwtpkg.NewWithClaims(jwtpkg.SigningMethodHS256, c)
	if nil == t {
		return "", ErrTokenGenFailed
	}
	ts, _ = t.SignedString(signKey) // 签名token实例，生成tokenString
	if len(ts) == 0 {
		return "", ErrTokenSignFailed
	}
	SetToken(id, ts)
	return ts, err
}

// 解析token，获取载荷
func ParseTokenString(tokenString string) (*jwtpkg.StandardClaims, error) {
	v, err, _ := concurrencyControl.Do("parse:"+tokenString, func() (interface{}, error) {
		return jwtpkg.ParseWithClaims(tokenString, &jwtpkg.StandardClaims{}, func(token *jwtpkg.Token) (i interface{}, e error) {
			return signKey, nil
		})
	})
	if err == nil && v != nil {
		t := v.(*jwtpkg.Token)
		if t.Valid { // 解析是否有效、合法的token
			if claims, ok := t.Claims.(*jwtpkg.StandardClaims); ok {
				return claims, nil
			}
		}
	}
	if err != nil {
		if ve, ok := err.(*jwtpkg.ValidationError); ok {
			if ve.Errors&jwtpkg.ValidationErrorMalformed != 0 { // token格式错误
				return nil, ErrTokenMalformed
			} else if ve.Errors&jwtpkg.ValidationErrorExpired != 0 { // token过期
				return nil, ErrTokenExpired
			} else if ve.Errors&jwtpkg.ValidationErrorNotValidYet != 0 { // 验证错误
				return nil, ErrTokenNotValidYet
			}
		}
	}
	return nil, ErrTokenInvalid
}

func GetClaims(ctx *gin.Context) *jwtpkg.StandardClaims {
	v, ok := ctx.Get("CONTEXT_CACHEKEY_CLAIMS")
	if !ok {
		return nil
	}
	return v.(*jwtpkg.StandardClaims)
}

func SetClaims(ctx *gin.Context, claims *jwtpkg.StandardClaims) {
	ctx.Set("CONTEXT_CACHEKEY_CLAIMS", claims)
}
